#ifndef STATSxx_MACHINE_LEARNING_RBM_HPP
#define STATSxx_MACHINE_LEARNING_RBM_HPP

// STL
#include <iostream> // std::istream, std::ostream
#include <string>   // std::string
#include <utility>  // std::pair<>, std::make_pair()
#include <vector>   // std::vector<>

// Boost
#include <boost/serialization/serialization.hpp> // boost::serialization::

// jScience
#include "jScience/linalg.hpp" // Matrix<>, Vector<>


// TODO: move away from the hard-coded integer codes for unit types (0 == binary; 1 == continuous; 2 == ReLU; 3 == softplus) and let the caller pass the unit types in directly, which would give a lot more flexibility


namespace machine_learning
{
    //=========================================================
    // RESTRICTED BOLTZMANN MACHINE
    //=========================================================
    //
    // NOTE: the "types" of units (visible or hidden) allowed are:
    //
    //     0 == binary
    //     1 == continuous
    //     2 == ReLU
    //     3 == softplus
    //
    class RBM
    {

    public:
        
        RBM();
        
        RBM(
            const int nv_,   
            const int nh_,    
            // -----
            const int vtype_,
            const int htype_  
            );
        
        ~RBM();
        
        void train(
                   const int                  nepoch,
                   const int                  nbatches,
                   // -----
                   const int                  K_start,            // starting number of CD steps
                   const int                  K_end,              // ending number
                   const double               K_rate,             // rate at which K_start transitions to K_end
                   // -----
                   const bool                 sample_v,           // whether to sample v during reconstructions --- this is correct and may lead to better density models (Hinton, practical guide), but not reduces sampling noise thus allowing faster learning
                   const bool                 sample_hdata,       // whether to sample h for the positive statistics --- true is closer to the mathematical model of an RBM
                   // -----
                   double                     lr,
                   const std::vector<double> &lr_dot,
                   const std::vector<double> &lr_adj,
                   const double               lr_min,
                   const double               lr_max,
                   // ---
                   const double               momentum_start,
                   const double               momentum_end,
                   const double               momentum_rate,
                   const std::vector<double> &momentum_dot,
                   const std::vector<double> &momentum_adj,
                   // -----
                   const double               weight_penalty,
                   // -----
                   const int                  free_energy_npts,   // number of training points to (randomly) select for free energy
                   const int                  free_energy_nepoch, // evaluate the free energy every number of epochs
                   const int                  free_energy_window, // calculate the slope of the free energy over this window
                   const double               convg_criterion_max_inc_max_w,
                   const int                  convg_criterion_nno_improvement,
                   // -----
                   const Matrix<double>      &X,                  // training set
                   const Matrix<double>      &X_val,              // validation set
                   // -----
                   const std::string          prefix
                   );
        
        std::pair<
                  Vector<double>,
                  Vector<double>
                  > reconstruct(
                                const Vector<double> &_v
                                ) const;
        
        std::pair<
                  Matrix<double>,
                  Matrix<double>
                  > reconstruct(
                                const Matrix<double> &V
                                ) const;
        
        Vector<double> gen_sample(
                                  const int  niter,
                                  // -----
                                  const bool sample_v
                                  ) const;
        
        Vector<double> v_to_h(
                              const Vector<double> &v
                              ) const;
        
        Matrix<double> v_to_h(
                              const Matrix<double> &V
                              ) const;
        
        Vector<double> h_to_v(
                              const Vector<double> &h
                              ) const;
        
/*
        Matrix<double> h_to_v(
                              const Matrix<double> &H
                              ) const;
*/
 
        double free_energy(
                           const Vector<double> &v
                           ) const;
        
        std::vector<double> free_energy(
                                        const Matrix<double> &V
                                        ) const;
        
        int get_nv() const;
        int get_nh() const;
        
        Matrix<double> get_W() const;
        Vector<double> get_a() const;
        Vector<double> get_b() const;
        
        // TODO: tmp because I wanted to assign easily ...
        // ... probably better to either provide subroutines to set_W, set_a, etc. --- or do this from a constructor
        Matrix<double> W;
        Vector<double> a;
        Vector<double> b;
        
 
    private:
        
        int nv;    // number of visible units
        int nh;    // ... hidden
        // ---
        int vtype; // 0 == binary; 1 == continuous; 2 == ReLU; 3 == softplus
        int htype; // ... same
        
        //Matrix<double> W;
        //Vector<double> a;
        //Vector<double> b;
        
        
        double free_energy_visible(
                                   const Vector<double> &v
                                   ) const;
        double free_energy_hidden(
                                  const Vector<double> &v
                                  ) const;
        double free_energy_const() const;
        
        Vector<double> calculate_p(
                                   const Vector<double> &x,
                                   const int             type
                                   );
        
        Vector<double> sample(
                              const Vector<double> &x,
                              const Vector<double> &px,
                              const int             type
                              );
        
        // (Boost) SERIALIZATION 
        friend class boost::serialization::access;
        
        template<class Archive>
        void serialize(
                       Archive            &ar,
                       const unsigned int  version
                       )
        {
            ar & this->nv;
            ar & this->nh;
            // ---
            ar & this->vtype;
            ar & this->htype;
            
            ar & this->W;
            ar & this->a;
            ar & this->b;
        }
    };
    
}

#include "statsxx/machine_learning/RBM/calculate_p.cpp"
#include "statsxx/machine_learning/RBM/free_energy.cpp"
#include "statsxx/machine_learning/RBM/get_a.cpp"
#include "statsxx/machine_learning/RBM/get_b.cpp"
#include "statsxx/machine_learning/RBM/get_nv.cpp"
#include "statsxx/machine_learning/RBM/get_nh.cpp"
#include "statsxx/machine_learning/RBM/get_W.cpp"
#include "statsxx/machine_learning/RBM/h_to_v.cpp"
#include "statsxx/machine_learning/RBM/reconstruct.cpp"
#include "statsxx/machine_learning/RBM/RBM.cpp"
#include "statsxx/machine_learning/RBM/sample.cpp"
#include "statsxx/machine_learning/RBM/train.cpp"
#include "statsxx/machine_learning/RBM/v_to_h.cpp"


#endif
